File: ChatCompletion\ChatCompletion{T}.cs
Web Access
Project: src\src\Libraries\Microsoft.Extensions.AI\Microsoft.Extensions.AI.csproj (Microsoft.Extensions.AI)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System;
using System.Buffers;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Text;
using System.Text.Json;
using System.Text.Json.Serialization.Metadata;
using Microsoft.Shared.Diagnostics;
 
namespace Microsoft.Extensions.AI;
 
/// <summary>Represents the result of a chat completion request with structured output.</summary>
/// <typeparam name="T">The type of value expected from the chat completion.</typeparam>
/// <remarks>
/// Language models are not guaranteed to honor the requested schema. If the model's output is not
/// parseable as the expected type, then <see cref="TryGetResult(out T)"/> will return <see langword="false"/>.
/// You can access the underlying JSON response on the <see cref="ChatCompletion.Message"/> property.
/// </remarks>
public class ChatCompletion<T> : ChatCompletion
{
    /// <summary>Reader options that tolerate additional top-level JSON values after the first one.</summary>
    private static readonly JsonReaderOptions _allowMultipleValuesJsonReaderOptions = new JsonReaderOptions { AllowMultipleValues = true };

    /// <summary>The serializer options used to resolve the <see cref="JsonTypeInfo{T}"/> for <typeparamref name="T"/>.</summary>
    private readonly JsonSerializerOptions _serializerOptions;

    /// <summary>The cached value produced by the first successful deserialization.</summary>
    private T? _deserializedResult;

    /// <summary>Indicates whether <see cref="_deserializedResult"/> holds a successfully deserialized value.</summary>
    private bool _hasDeserializedResult;

    /// <summary>Initializes a new instance of the <see cref="ChatCompletion{T}"/> class.</summary>
    /// <param name="completion">The unstructured <see cref="ChatCompletion"/> that is being wrapped.</param>
    /// <param name="serializerOptions">The <see cref="JsonSerializerOptions"/> to use when deserializing the result.</param>
    public ChatCompletion(ChatCompletion completion, JsonSerializerOptions serializerOptions)
        : base(Throw.IfNull(completion).Choices)
    {
        _serializerOptions = Throw.IfNull(serializerOptions);

        // Mirror the wrapped completion's metadata onto this instance.
        CompletionId = completion.CompletionId;
        ModelId = completion.ModelId;
        CreatedAt = completion.CreatedAt;
        FinishReason = completion.FinishReason;
        Usage = completion.Usage;
        RawRepresentation = completion.RawRepresentation;
        AdditionalProperties = completion.AdditionalProperties;
    }

    /// <summary>
    /// Gets the result of the chat completion as an instance of <typeparamref name="T"/>.
    /// </summary>
    /// <remarks>
    /// If the response did not contain JSON, or if deserialization fails, this property will throw.
    /// To avoid exceptions, use <see cref="TryGetResult(out T)"/> instead.
    /// </remarks>
    /// <exception cref="InvalidOperationException">
    /// The response did not contain text to deserialize, the expected <c>data</c> wrapper property
    /// was missing, or deserialization produced <see langword="null"/>.
    /// </exception>
    /// <exception cref="JsonException">The response text could not be parsed or deserialized.</exception>
    public T Result
    {
        get
        {
            var result = GetResultCore(out var failureReason);
            return failureReason switch
            {
                FailureReason.ResultDidNotContainJson => throw new InvalidOperationException("The response did not contain text to be deserialized"),
                FailureReason.DeserializationProducedNull => throw new InvalidOperationException("The deserialized response is null"),
                FailureReason.ResultDidNotContainDataProperty => throw new InvalidOperationException("The response did not contain the expected 'data' property"),
                _ => result!,
            };
        }
    }

    /// <summary>
    /// Attempts to deserialize the result to produce an instance of <typeparamref name="T"/>.
    /// </summary>
    /// <param name="result">When this method returns, contains the result.</param>
    /// <returns><see langword="true"/> if the result was produced, otherwise <see langword="false"/>.</returns>
    /// <remarks>
    /// Any exception raised while reading or deserializing the response is swallowed and reported
    /// as a <see langword="false"/> return value.
    /// </remarks>
    public bool TryGetResult([NotNullWhen(true)] out T? result)
    {
        try
        {
            result = GetResultCore(out var failureReason);
            return failureReason is null;
        }
#pragma warning disable CA1031 // Do not catch general exception types
        catch
        {
            result = default;
            return false;
        }
#pragma warning restore CA1031 // Do not catch general exception types
    }

    /// <summary>Deserializes the first top-level JSON value in <paramref name="json"/>.</summary>
    /// <param name="json">The JSON text to read.</param>
    /// <param name="typeInfo">The type metadata used to deserialize <typeparamref name="T"/>.</param>
    /// <returns>The deserialized value, which may be <see langword="null"/>.</returns>
    private static T? DeserializeFirstTopLevelObject(string json, JsonTypeInfo<T> typeInfo)
    {
        // We need to deserialize only the first top-level object as a workaround for a common LLM backend
        // issue. GPT 3.5 Turbo commonly returns multiple top-level objects after doing a function call.
        // See https://community.openai.com/t/2-json-objects-returned-when-using-function-calling-and-json-mode/574348
        var utf8ByteLength = Encoding.UTF8.GetByteCount(json);
        var buffer = ArrayPool<byte>.Shared.Rent(utf8ByteLength);
        try
        {
            // The rented array may be larger than requested, so slice to the bytes actually written.
            var utf8SpanLength = Encoding.UTF8.GetBytes(json, 0, json.Length, buffer, 0);
            var utf8Span = new ReadOnlySpan<byte>(buffer, 0, utf8SpanLength);
            var reader = new Utf8JsonReader(utf8Span, _allowMultipleValuesJsonReaderOptions);
            return JsonSerializer.Deserialize(ref reader, typeInfo);
        }
        finally
        {
            ArrayPool<byte>.Shared.Return(buffer);
        }
    }

    /// <summary>
    /// Gets or sets a value indicating whether the JSON schema has an extra object wrapper.
    /// </summary>
    /// <remarks>
    /// The wrapper is required for any non-JSON-object-typed values such as numbers, enum values, and arrays.
    /// </remarks>
    internal bool IsWrappedInObject { get; set; }

    /// <summary>Gets the response text when there is exactly one choice holding exactly one text content item.</summary>
    /// <returns>The text of that content item; otherwise, <see langword="null"/>.</returns>
    private string? GetResultAsJson()
    {
        var choice = Choices.Count == 1 ? Choices[0] : null;
        var content = choice?.Contents.Count == 1 ? choice.Contents[0] : null;
        return (content as TextContent)?.Text;
    }

    /// <summary>Produces the deserialized result, caching it after the first success.</summary>
    /// <param name="failureReason">
    /// When this method returns, <see langword="null"/> on success; otherwise the reason no result was produced.
    /// </param>
    /// <returns>The deserialized value on success; otherwise the default value of <typeparamref name="T"/>.</returns>
    private T? GetResultCore(out FailureReason? failureReason)
    {
        // Only successful results are cached; failures are re-evaluated on every call.
        if (_hasDeserializedResult)
        {
            failureReason = default;
            return _deserializedResult;
        }

        var json = GetResultAsJson();
        if (string.IsNullOrEmpty(json))
        {
            failureReason = FailureReason.ResultDidNotContainJson;
            return default;
        }

        // If there's an exception here, we want it to propagate, since the Result property is meant to throw directly

        if (IsWrappedInObject)
        {
            // JsonDocument rents pooled memory, so it must be disposed. GetRawText copies the payload
            // into a new string, which remains valid after the document is released.
            // NOTE(review): the wrapper is parsed with default options, so the multiple-top-level-value
            // workaround used for the payload presumably does not cover the wrapper itself - confirm intent.
            using var document = JsonDocument.Parse(json!);
            JsonElement root = document.RootElement;

            // TryGetProperty throws for non-object elements; report that as a missing 'data' property instead.
            if (root.ValueKind != JsonValueKind.Object || !root.TryGetProperty("data", out JsonElement data))
            {
                failureReason = FailureReason.ResultDidNotContainDataProperty;
                return default;
            }

            json = data.GetRawText();
        }

        var deserialized = DeserializeFirstTopLevelObject(json!, (JsonTypeInfo<T>)_serializerOptions.GetTypeInfo(typeof(T)));

        if (deserialized is null)
        {
            failureReason = FailureReason.DeserializationProducedNull;
            return default;
        }

        _deserializedResult = deserialized;
        _hasDeserializedResult = true;
        failureReason = default;
        return deserialized;
    }

    /// <summary>Identifies why a result could not be produced.</summary>
    private enum FailureReason
    {
        /// <summary>The response had no text to deserialize.</summary>
        ResultDidNotContainJson,

        /// <summary>Deserialization completed but yielded <see langword="null"/>.</summary>
        DeserializationProducedNull,

        /// <summary>The wrapped response had no <c>data</c> property.</summary>
        ResultDidNotContainDataProperty,
    }
}